import os
import gc
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data.dataset import TensorDataset

def load_data(path, suffix, batch_size = 128):
    """Build train/test ``DataLoader`` objects from ``<path>/{loc,vel,edges}_<suffix>.npy``.

    The three arrays are split together 80/20 with a fixed seed so samples
    stay aligned.  Locations and velocities are min-max scaled to [-1, 1]
    using statistics of the *training* split only, flattened per timestep and
    concatenated into one feature vector.  Edge matrices are flattened and
    restricted to their off-diagonal entries.

    Args:
        path: Directory containing the ``.npy`` files.
        suffix: Dataset suffix used in the file names.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``.  Each batch is ``(features, edges)``,
        both float tensors: ``features`` keeps the first two axes of the raw
        arrays and flattens the rest (loc then vel on the last axis);
        ``edges`` has ``num_atoms * (num_atoms - 1)`` columns.  Only the
        training loader shuffles.
    """
    def read(prefix):
        return np.load(os.path.join(path, prefix + suffix + '.npy'))

    (loc_train, loc_test,
     vel_train, vel_test,
     edges_train, edges_test) = train_test_split(
        read('loc_'), read('vel_'), read('edges_'),
        test_size=0.2, random_state=42)

    # Raw layout (per the original author's note):
    # [num_samples, num_timesteps, num_dims, num_atoms]
    num_atoms = loc_train.shape[3]

    # Scaling bounds come from the training split and are reused for the test
    # split so that no test statistics leak into preprocessing.
    loc_min, loc_max = loc_train.min(), loc_train.max()
    vel_min, vel_max = vel_train.min(), vel_train.max()

    def rescale(values, low, high):
        # Affine map sending low -> -1 and high -> +1.
        return (values - low) * 2 / (high - low) - 1

    def build_features(loc, vel):
        # [samples, timesteps, dims, atoms] -> [samples, timesteps, dims * atoms],
        # then loc and vel side by side on the last axis.
        loc = rescale(loc, loc_min, loc_max)
        vel = rescale(vel, vel_min, vel_max)
        loc = loc.reshape(*loc.shape[:2], -1)
        vel = vel.reshape(*vel.shape[:2], -1)
        return torch.FloatTensor(np.concatenate([loc, vel], axis=-1))

    # Flat (row-major) positions of the off-diagonal entries of an
    # num_atoms x num_atoms matrix, i.e. every pair except self-pairs.
    off_diag_idx = np.flatnonzero(
        np.ones((num_atoms, num_atoms)) - np.eye(num_atoms))

    def build_targets(edges):
        flat = np.reshape(edges, [-1, num_atoms ** 2])
        # (e + 1) / 2 truncated to an integer, then stored as float targets.
        flat = np.array((flat + 1) / 2, dtype=np.int64)
        return torch.FloatTensor(flat)[:, off_diag_idx]

    train_data = TensorDataset(build_features(loc_train, vel_train),
                               build_targets(edges_train))
    test_data = TensorDataset(build_features(loc_test, vel_test),
                              build_targets(edges_test))

    return (
        Data.DataLoader(train_data, batch_size=batch_size, shuffle=True),
        Data.DataLoader(test_data, batch_size=batch_size),
    )

class CTRNN(nn.Module):
    """Continuous-time RNN integrated with fixed-step explicit Euler.

    For every input step the hidden state ``h`` is advanced ``num_unfolds``
    times along

        dh/dt = scale * tanh(kernel(x) + recurrent_kernel(h)) - tau * h

    (the decay term is omitted when ``tau <= 0``).  After the last step a
    linear decoder maps the final hidden state to the output.

    Args:
        device: Stored as ``self.device`` for backward compatibility.  The
            forward pass allocates its state next to the input instead, so
            the module keeps working after ``.to(...)`` / ``.double()``.
        input_size: Number of input features per step.
        hidden_size: Width of the hidden state.
        output_size: Width of the decoder output.
        num_unfolds: Euler sub-steps taken per input step.
        tau: Coefficient multiplying the hidden state in the decay term.
    """

    def __init__(self, device, input_size: int, hidden_size: int, output_size: int, num_unfolds: int = 3, tau: int = 1):
        super().__init__()
        self.device = device
        self.units = hidden_size
        self.state_size = hidden_size
        self.num_unfolds = num_unfolds
        self.tau = tau

        self.kernel = nn.Linear(input_size, hidden_size)
        self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
        self.scale = nn.Parameter(torch.ones(hidden_size))

        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, t, x):
        """Integrate over all steps and decode the final hidden state.

        Args:
            t: ``(batch_size, n_take)`` elapsed time attributed to each step.
            x: ``(batch_size, n_take, input_size)`` inputs.

        Returns:
            ``(batch_size, output_size)`` decoder output (raw, no activation).
        """
        # Allocate the state with the input's device and dtype rather than the
        # device captured at construction time; otherwise a model that was
        # moved or cast after __init__ fails on the first recurrent matmul.
        hidden_state = torch.zeros(
            (t.size(0), self.units), device=x.device, dtype=x.dtype)
        for i in range(t.size(1)):
            step_input = x[:, i, :]
            # Each input step is split into num_unfolds equal Euler sub-steps.
            delta_t = t[:, i] / self.num_unfolds
            for _ in range(self.num_unfolds):
                hidden_state = self.euler(step_input, hidden_state, delta_t)
        return self.decoder(hidden_state)

    def dfdt(self, inputs, hidden_state):
        """Return the time derivative of the hidden state."""
        dh_in = self.scale * (self.kernel(inputs) + self.recurrent_kernel(hidden_state)).tanh()
        if self.tau > 0:
            dh = dh_in - hidden_state * self.tau
        else:
            dh = dh_in
        return dh

    def euler(self, inputs, hidden_state, delta_t):
        """Take one explicit Euler step of size ``delta_t``."""
        derivative = self.dfdt(inputs, hidden_state)
        # delta_t is per-sample (batch,), the derivative is (batch, hidden):
        # transpose so the batch axis is last and broadcasting lines up, then
        # transpose back.
        increment = (delta_t * derivative.permute(1, 0)).permute(1, 0)
        return hidden_state + increment

def _run_epoch(model, loader, criterion, device, n_take, optimizer=None, scheduler=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch (``scheduler``, if given, is stepped per batch as well).
    Without one the model is put in eval mode and gradients are disabled.

    Both metrics are averaged over individual predictions (samples x outputs).
    ``criterion`` is expected to return a mean-reduced loss.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_preds = 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            # Every step is given unit duration.
            t = torch.ones((x.size(0), n_take)).to(device)
            x, y = x.to(device), y.to(device)
            output = model(t, x)
            loss = criterion(output, y)

            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

            # A logit > 0 corresponds to a predicted probability > 0.5.
            total_correct += int(torch.sum((output > 0).int() == y))
            # The loss is a mean over every element of y, so weight it by the
            # element count; weighting by batch size alone while dividing by
            # the element count would under-report it by a factor y.size(1).
            total_loss += loss.item() * y.numel()
            total_preds += y.numel()

    return total_loss / total_preds, total_correct / total_preds


def _write_metrics(frame, version, split):
    """Save ``frame`` as ``<version>_<split>.csv``; on I/O failure use ``<version>_<split>_1.csv``."""
    try:
        frame.to_csv(f'{version}_{split}.csv')
    except OSError:
        # Only I/O problems (e.g. the file being locked) trigger the fallback;
        # KeyboardInterrupt and programming errors are left to propagate.
        print(f'Fail to save the file {split}.csv')
        frame.to_csv(f'{version}_{split}_1.csv')


def main(version, hidden_size):
    """Train and evaluate a CTRNN on the spring dataset, saving artifacts each epoch.

    After every epoch the whole model is pickled to ``<version>.pkl`` and the
    metric histories are written to ``<version>_Train.csv`` and
    ``<version>_Test.csv``.

    Args:
        version: Tag used as the prefix of every output file and in log lines.
        hidden_size: Hidden-state width of the CTRNN.
    """
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This is a function and must be *called*; assigning True to it replaces
    # the function on the torch module and enables nothing.  warn_only keeps
    # ops without a deterministic implementation from aborting the run.
    # NOTE: full determinism of cuBLAS on CUDA additionally requires the
    # CUBLAS_WORKSPACE_CONFIG environment variable to be set.
    torch.use_deterministic_algorithms(True, warn_only=True)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    batch_size = 128
    num_epochs = 25
    # Number of timesteps fed to the model per sample.
    n_take = 49

    train_loader, test_loader = load_data('spring', 'springs5', batch_size)

    print(device)

    # NOTE(review): 20 inputs / 20 outputs are hard-coded; presumably they
    # match the feature and edge widths produced by load_data for
    # 'springs5' -- confirm before changing the dataset.
    model = CTRNN(device, input_size=20, hidden_size=hidden_size, output_size=20).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The scheduler is stepped once per batch, so its period is measured in
    # batches rather than epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(train_loader), eta_min=0.00005, last_epoch=-1)

    criterion = nn.BCEWithLogitsLoss()

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))

        loss, acc = _run_epoch(model, train_loader, criterion, device, n_take,
                               optimizer=optimizer, scheduler=scheduler)
        train_loss.append(loss)
        train_acc.append(acc)
        print(' ', train_loss[-1], train_acc[-1])

        loss, acc = _run_epoch(model, test_loader, criterion, device, n_take)
        test_loss.append(loss)
        test_acc.append(acc)
        print(' ', test_loss[-1], test_acc[-1])

        # Checkpoint and metrics are rewritten every epoch so an interrupted
        # run still leaves usable results behind.
        torch.save(model, f'{version}.pkl')

        _write_metrics(
            pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}),
            version, 'Train')
        _write_metrics(
            pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}),
            version, 'Test')

    gc.collect()

if __name__ == '__main__':
    # One independent training run per hidden width; the width is part of the
    # version tag so each run writes its own checkpoint and CSV files.
    for hidden_size in (128, 256, 512):
        main(f'CTRNN_Spring_{hidden_size}', hidden_size)